
[PyTorch] Support single parameter for GroupedLinear#2731

Merged
ksivaman merged 6 commits into NVIDIA:main from ksivaman:single_param_grouped_weight
Mar 4, 2026

Conversation

@ksivaman (Member)

@ksivaman ksivaman commented Mar 4, 2026

Description

Support option to expose single parameter for GroupedLinear module.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Support option to expose single parameter for GroupedLinear module.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from zhongbozhu March 4, 2026 03:09
@ksivaman (Member Author)

ksivaman commented Mar 4, 2026

/te-ci pytorch L0

@greptile-apps (Contributor)

greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR introduces a single_grouped_parameter=True option for GroupedLinear that consolidates all per-GEMM weight parameters into a single nn.Parameter backed by a new GroupedTensor wrapper subclass (a torch.Tensor subclass layered on top of the renamed GroupedTensorStorage). The C++ quantizer layer is updated to instantiate either GroupedTensor or GroupedTensorStorage depending on quantizer.internal, and a new __torch_dispatch__ protocol dequantizes members into a stacked tensor, performs the requested op, and writes results back in-place.

Core changes:

  • New GroupedTensor class (grouped_tensor.py): inherits from both GroupedTensorStorage and torch.Tensor via _make_wrapper_subclass; bans shape-manipulation ops and handles both the in-place and default dispatch paths.
  • GroupedTensorStorage rename (grouped_tensor_storage.py): grouped_tensor.py → grouped_tensor_storage.py; GroupedTensor → GroupedTensorStorage; data → rowwise_data; shape → tensor_shapes; logical_shape promoted to a required positional argument.
  • grouped_linear.py: adds a single_grouped_parameter constructor flag, a make_grouped_weights() method that registers one "weight" parameter and nulls out the individual "weight{i}" slots, and adapts _get_weight_tensors() to handle either path.
  • C++ quantizer layer (quantizer.cpp, pybind.cpp, pybind.h): all six create_grouped_tensor methods switched from pybind11 named-arg syntax to PyObject_Call with a kwargs dict, dispatching to GroupedTensorPythonClass or GroupedTensorStoragePythonClass.

Minor issue found: a comment line duplicated on consecutive lines in make_grouped_weights.

The core refactoring of the storage class, API cleanup, and C++ dispatch update is sound, with appropriate assertion guards against incompatible quantizer configurations.
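To make the parameter consolidation concrete: instead of registering one 2D nn.Parameter per GEMM, the module can expose a single stacked 3D parameter and derive per-GEMM views from it. The sketch below illustrates only that shape relationship; the names and sizes are illustrative and not the PR's actual implementation.

```python
import torch

num_gemms, out_features, in_features = 4, 8, 16

# Default path: one 2D parameter per GEMM (weight0 ... weight3).
per_gemm = [
    torch.nn.Parameter(torch.empty(out_features, in_features))
    for _ in range(num_gemms)
]

# single_grouped_parameter path (sketched): one stacked 3D parameter.
grouped = torch.nn.Parameter(torch.empty(num_gemms, out_features, in_features))

# Per-GEMM 2D views into the single parameter, no copies.
views = list(grouped.unbind(0))
```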

Confidence Score: 3/5

  • Minor style issue identified; core refactoring appears sound for the experimental single_grouped_parameter feature.
  • Only one minor style comment found in this filtered review (duplicate comment line), which is easily correctable. The refactoring of the storage class, API updates, and C++ dispatch changes are well-structured. However, this is still an experimental new feature being added to the codebase, so a moderate confidence is appropriate.
  • transformer_engine/pytorch/module/grouped_linear.py (remove duplicate comment line)

Sequence Diagram

sequenceDiagram
    participant GL as GroupedLinear.__init__
    participant RP as reset_parameters
    participant MGW as make_grouped_weights
    participant GTS as GroupedTensorStorage.make_grouped_tensor
    participant GT as GroupedTensor (wrapper)
    participant GTS2 as GroupedTensorStorage (internal)

    GL->>RP: single_grouped_parameter=True
    RP->>MGW: make_grouped_weights(defer_init)
    MGW->>GTS: make_grouped_tensor_with_shapes(shapes, quantizer)
    alt quantizer.internal == False (or quantizer is None)
        GTS->>GT: GroupedTensor(shape, dtype, num_tensors, ...)
        GT-->>MGW: grouped_weights (torch.Tensor subclass)
        MGW->>GL: register_parameter("weight", nn.Parameter(grouped_weights))
    else quantizer.internal == True
        GTS->>GTS2: GroupedTensorStorage(shape, dtype, num_tensors, ...)
        GTS2-->>MGW: grouped_weights (plain Python object)
        MGW->>GL: assert fails → error raised
    end
    MGW->>GL: register_parameter("weight{i}", None) for each GEMM

    note over GT: GroupedTensor.__torch_dispatch__ intercepts ops,<br/>dequantizes members → stacked tensor → op → requantizes

Last reviewed commit: ef58675

Comment on lines +789 to +796
# Re-register as a single grouped weight parameter.
self.register_parameter(
    "weight",
    torch.nn.Parameter(grouped_weights),
    init_fn=self.init_method,
    get_rng_state_tracker=self.get_rng_state_tracker,
    fp8_meta_index=self._offsets["weight"],
)
Contributor

nn.Parameter wrapping will crash if quantizer.internal=True

GroupedTensorStorage.make_grouped_tensor() branches on quantizer.internal: when True it returns a plain GroupedTensorStorage, which is not a torch.Tensor subclass. Passing that object to torch.nn.Parameter(...) will raise a TypeError at runtime.

The early-return guard at line 766 covers delayed() and float8_current_scaling(), but not MXFP8, Float8BlockScaling, or NVFP4 with an internal quantizer. If any of those quantizers are used as weight quantizers with internal=True, this line will crash:

self.register_parameter("weight", torch.nn.Parameter(grouped_weights), ...)

Consider adding an explicit assertion before this call:

Suggested change
# Re-register as a single grouped weight parameter.
self.register_parameter(
    "weight",
    torch.nn.Parameter(grouped_weights),
    init_fn=self.init_method,
    get_rng_state_tracker=self.get_rng_state_tracker,
    fp8_meta_index=self._offsets["weight"],
)
# Re-register as a single grouped weight parameter.
assert isinstance(grouped_weights, torch.Tensor), (
    "single_grouped_parameter requires a GroupedTensor (torch.Tensor subclass); "
    "got GroupedTensorStorage (quantizer.internal=True is unsupported here)."
)
self.register_parameter(
    "weight",
    torch.nn.Parameter(grouped_weights),
    init_fn=self.init_method,
    get_rng_state_tracker=self.get_rng_state_tracker,
    fp8_meta_index=self._offsets["weight"],
)
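The failure mode described above can also be guarded generically: torch.nn.Parameter only accepts torch.Tensor (sub)classes, so an isinstance check fails fast with a clear message. A minimal sketch, where PlainStorage and as_parameter are hypothetical stand-ins, not code from this PR:

```python
import torch


class PlainStorage:
    """Stand-in for a non-Tensor storage object such as GroupedTensorStorage."""


def as_parameter(obj: object) -> torch.nn.Parameter:
    # Fail fast with a clear message instead of an opaque error inside nn.Parameter.
    if not isinstance(obj, torch.Tensor):
        raise TypeError(f"expected a torch.Tensor subclass, got {type(obj).__name__}")
    return torch.nn.Parameter(obj)


param = as_parameter(torch.zeros(2, 2))
```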

Comment on lines 266 to 274
def __repr__(self) -> str:
    """String representation of the GroupedTensor."""
    return (
        f"GroupedTensor(num_tensors={self.num_tensors}, "
        f"shape={self.shape}, "
        f"shapes={self.tensor_shapes}, "
        f"logical_shape={self.logical_shape}, "
        f"quantizer={self.quantizer}, "
        f"dtype={self.get_dtype()})"
    )
Contributor

Stale class name in __repr__ output

After the rename of this class from GroupedTensor to GroupedTensorStorage, the repr string still emits GroupedTensor(...). This makes it confusing to distinguish the storage object from the new GroupedTensor wrapper when debugging.

Suggested change
def __repr__(self) -> str:
    """String representation of the GroupedTensor."""
    return (
        f"GroupedTensor(num_tensors={self.num_tensors}, "
        f"shape={self.shape}, "
        f"shapes={self.tensor_shapes}, "
        f"logical_shape={self.logical_shape}, "
        f"quantizer={self.quantizer}, "
        f"dtype={self.get_dtype()})"
    )
def __repr__(self) -> str:
    """String representation of the GroupedTensorStorage."""
    return (
        f"GroupedTensorStorage(num_tensors={self.num_tensors}, "
        f"shapes={self.tensor_shapes}, "
        f"logical_shape={self.logical_shape}, "
        f"quantizer={self.quantizer}, "
        f"dtype={self.get_dtype()})"
    )

Comment on lines 277 to 296
@@ -314,20 +296,20 @@ def make_grouped_tensor_with_shapes(
"""
Contributor

Return type annotation is too narrow

make_grouped_tensor_with_shapes() (and make_grouped_tensor() at line 333) are annotated as returning GroupedTensorStorage, but they actually return a GroupedTensor (a torch.Tensor subclass) when quantizer.internal is False — which is the common case for user-facing weight parameters.

Looking at lines 564–569 of make_grouped_tensor():

internal = False if quantizer is None else quantizer.internal
if internal:
    grouped_tensor_class = GroupedTensorStorage
else:
    from ..grouped_tensor import GroupedTensor
    grouped_tensor_class = GroupedTensor

Callers in grouped_linear.py wrap the return value in torch.nn.Parameter, which only works for torch.Tensor subclasses. The annotation does not convey this requirement and will mislead type-checkers.

Consider updating the return type annotation to Union[GroupedTensorStorage, GroupedTensor] or adding a note to the docstring clarifying that the returned type may be a GroupedTensor when quantizer.internal=False.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines 818 to 823
set_tensor_model_parallel_attributes(
-   tensor=getattr(self, f"weight{i}"),
+   tensor=grouped_weight,
    is_parallel=True,
    dim=1 if self.parallel_mode == "row" else 0,
    stride=1,
)
Contributor

Wrong partition_dim for 3D grouped weight tensor

When single_grouped_parameter=True, the grouped weight has shape [num_gemms, out_features, in_features]. The same dim values used for the per-GEMM 2D weights (out_features, in_features) are reused here without adjustment:

  • "row" parallel → dim=1 partitions along out_features — but it should partition along in_featuresdim=2
  • "column" parallel → dim=0 partitions along num_gemms — but it should partition along out_featuresdim=1

This causes the wrong axis to be sharded, breaking any tensor-parallel run that uses single_grouped_parameter=True.

Suggested change
set_tensor_model_parallel_attributes(
    tensor=grouped_weight,
    is_parallel=True,
    dim=1 if self.parallel_mode == "row" else 0,
    stride=1,
)
set_tensor_model_parallel_attributes(
    tensor=grouped_weight,
    is_parallel=True,
    dim=2 if self.parallel_mode == "row" else 1,
    stride=1,
)
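The off-by-one is easy to see with concrete shapes: for a 3D grouped weight of [num_gemms, out_features, in_features], every shard axis shifts up by one relative to the 2D per-GEMM case. A small illustration (sizes are made up):

```python
import torch

num_gemms, out_features, in_features = 4, 8, 6
grouped = torch.zeros(num_gemms, out_features, in_features)

# "row" parallel shards in_features: dim=2 for the 3D tensor (dim=1 for a 2D weight).
row_shard = torch.chunk(grouped, 2, dim=2)[0]

# "column" parallel shards out_features: dim=1 for the 3D tensor (dim=0 for a 2D weight).
col_shard = torch.chunk(grouped, 2, dim=1)[0]
```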

Comment on lines +187 to +198
    super().__torch_dispatch__(func, types, new_args, new_kwargs)
    for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
        maybe_update_inplace(arg, new_arg, schema_arg)
    for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
        assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema"
        maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
    return None

# Default op: operate on dequantized stacked tensors.
new_args = tree_map(maybe_unwrap, args)
new_kwargs = tree_map(maybe_unwrap, kwargs)
return super().__torch_dispatch__(func, types, new_args, new_kwargs)
Contributor

super().__torch_dispatch__ passes original types containing GroupedTensor

Both the in-place and default dispatch paths call:

super().__torch_dispatch__(func, types, new_args, new_kwargs)

At this point new_args has already been unwrapped (all GroupedTensor instances replaced with plain stacked tensors), but types still contains GroupedTensor. PyTorch's C++ dispatch layer examines types when deciding whether to re-dispatch; passing the original types while the actual tensor arguments are plain tensors can cause the dispatch to call GroupedTensor.__torch_dispatch__ again, leading to infinite recursion.

The idiomatic pattern for a wrapper subclass that has already substituted all its arguments is to call the op directly:

# In-place path:
func(*new_args, **new_kwargs)

# Default path:
return func(*new_args, **new_kwargs)

This avoids any re-dispatch via types and directly executes the ATen kernel on the unwrapped tensors.

Comment on lines +190 to +192
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
    assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema"
    maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
Contributor

Kwargs-to-schema alignment is fragile

The loop

for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
    assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema"

makes two implicit assumptions that can break:

  1. Order: kwargs (and new_kwargs) enumerate their keys in insertion order, but schema_args[args_len:] lists all remaining schema arguments, including any that were not actually passed as kwargs. If the caller omits an optional argument that appears before the first actual kwarg in the schema, the zip would pair them incorrectly, triggering the assertion.

  2. Coverage: zip silently stops at the shortest iterable, so in-place writeback for kwargs that appear later in the schema than the number of passed kwargs is silently skipped.

A safer approach is to match kwargs by name against the schema:

schema_arg_by_name = {a.name: a for a in schema_args[args_len:]}
for kwarg in kwargs:
    schema_arg = schema_arg_by_name.get(kwarg)
    if schema_arg is not None:
        maybe_update_inplace(kwargs[kwarg], new_kwargs[kwarg], schema_arg)
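The misalignment hazard is easy to reproduce with plain dicts: pairing by position goes wrong as soon as the caller skips an optional schema argument, while pairing by name does not. A minimal demonstration with a made-up schema:

```python
from collections import namedtuple

SchemaArg = namedtuple("SchemaArg", "name")
schema_args = [SchemaArg("alpha"), SchemaArg("beta"), SchemaArg("out")]

kwargs = {"out": "buffer"}  # caller skipped the optional alpha/beta

# Positional pairing: zip aligns "out" with the schema's "alpha" and stops there.
positional = [(k, s.name) for k, s in zip(kwargs, schema_args)]

# By-name pairing: each passed kwarg finds its own schema entry.
by_name = {a.name: a for a in schema_args}
matched = [(k, by_name[k].name) for k in kwargs if k in by_name]
```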

    return arg

def update_grouped_tensor_inplace(grouped: GroupedTensor, updated: torch.Tensor):
    if not grouped.all_same_shape():
Collaborator

Since this check is also repeated in grouped_to_stacked_tensor, might it make sense to move it to the start of the function and do it once?

Member Author

They are 2 separate checks but the execution only happens once, either for the inplace case, or the normal case.

Collaborator

Oh, I was saying this from a code-duplication perspective. Anyway, this is minor; will leave it up to you.

PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
std::once_flag extension_init_flag;
Collaborator

This flag is used below to initialize all extensions once

Member Author

Yep, just moved 2 lines below


strides = [1] * len(wrapper_shape)
for i in range(len(wrapper_shape) - 2, -1, -1):
    strides[i] = strides[i + 1] * wrapper_shape[i + 1]
@vthumbe1503 (Collaborator) Mar 4, 2026

In other classes, I have recently added stride as an argument to the constructor.

This allows the class to be created from C++ with less CPU overhead; any Python compute in __new__ adds a lot of CPU overhead.

This logic can then run only when stride is not provided (which is the case for object creation from Python).
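For reference, the quoted loop computes standard row-major (contiguous) strides. Extracted as a standalone function:

```python
def contiguous_strides(shape):
    """Row-major strides: stride[i] is the product of all dimensions after i."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# e.g. a [num_gemms=4, out=3, in=2] grouped weight has strides [6, 2, 1]
```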

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +122 to +124
# Parameter construction calls detach()/alias-like paths.
if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
    return args[0]
Contributor

detach short-circuit returns the original GroupedTensor without detaching

Returning args[0] directly for detach.default and alias.default means the returned tensor is the same object as the original — it is not a new detached view. This is intentional for wrapper subclasses in some cases, but it has a subtle consequence: torch.nn.Parameter(grouped_weights) internally calls detach() and expects a logically detached tensor. Returning the original GroupedTensor means there is only one object being both the raw tensor and the nn.Parameter backing storage, which can confuse gradient tracking and .data access patterns.

A more conventional pattern for wrapper subclasses is to create a shallow clone via _make_wrapper_subclass to represent the detached view, preserving the same storage but returning a distinct object:

if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
    out = args[0].__class__.__new__(args[0].__class__, ...)
    # copy over GroupedTensorStorage state fields
    return out

At minimum, consider documenting why returning args[0] directly is safe here.

ksivaman and others added 3 commits March 4, 2026 10:19
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
if grouped.quantizer is None:
    grouped_members = grouped.quantized_tensors
    if grouped_members is None:
        grouped_members = grouped.split_into_quantized_tensors()
Collaborator

Add some fallback logic here:

  • if a grouped quantize kernel is available, use it
  • if it's not available, split the tensor and trigger quantize one by one
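The reviewer's suggested fallback can be sketched in a few lines. All names here (grouped_quantize, the stub classes) are hypothetical placeholders for whatever the real quantizer API ends up exposing, not code from this PR:

```python
class _StubQuantizer:
    """Hypothetical quantizer with only per-tensor quantize (no fused kernel)."""

    def quantize(self, t):
        return ("q", t)


class _StubGrouped:
    """Hypothetical grouped tensor that can split into member tensors."""

    def split_into_quantized_tensors(self):
        return [1, 2, 3]


def quantize_grouped(grouped, quantizer):
    # Prefer a fused grouped-quantize kernel when the quantizer provides one;
    # otherwise fall back to splitting and quantizing members one by one.
    if hasattr(quantizer, "grouped_quantize"):
        return quantizer.grouped_quantize(grouped)
    return [quantizer.quantize(t) for t in grouped.split_into_quantized_tensors()]


members = quantize_grouped(_StubGrouped(), _StubQuantizer())
```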

)
grouped_members = grouped.quantized_tensors
if grouped_members is None:
    grouped_members = grouped.split_into_quantized_tensors()
Collaborator

same comment, add some fallback check

Comment on lines +789 to +790
# Re-register as a single grouped weight parameter.
# Re-register as a single grouped weight parameter.
Contributor

Duplicate comment line

The comment # Re-register as a single grouped weight parameter. appears on both lines 789 and 790; this is a copy-paste artifact. Remove one of them.

Suggested change
# Re-register as a single grouped weight parameter.
# Re-register as a single grouped weight parameter.
# Re-register as a single grouped weight parameter.
assert isinstance(grouped_weights, torch.Tensor) and (

@zhongbozhu (Collaborator) left a comment

Mostly LGTM. We need to add some logic around split_into_quantized_tensors, since that API is not going to be performant; we shouldn't need to split and then quantize once a grouped quantize kernel for the weights is ready.

@ksivaman (Member Author)

ksivaman commented Mar 4, 2026

/te-ci pytorch L0

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
    """Get the weight tensors of the module."""
    weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
    grouped_weight = getattr(self, "weight", None)
@zhongbozhu (Collaborator) Mar 4, 2026

There is another case:

  • Suppose fp8 primary weights are already supported for the grouped single weight; then you also shouldn't split it, because that means you can directly call the GEMM.
  • Suppose we didn't turn on fp8 primary weights but still turn on the single weight; then we should (1) call the grouped quantize kernel for the weight if it is supported, and (2) if it's not, call this API to split and quantize it. It's actually better to call a split_quantize here instead, which is more performant.

This basically means that when _get_weight_tensors is called, it shouldn't try to split the weight beforehand; we should instead try to call grouped_quantize or split_quantize here:

@ksivaman ksivaman merged commit bf3201a into NVIDIA:main Mar 4, 2026
9 of 14 checks passed
@ksivaman ksivaman deleted the single_param_grouped_weight branch March 4, 2026 06:51